import torch
from libauc.losses import PSH
from surrogate import squared_loss, squared_hinge_loss
import numpy as np
import torch.nn as nn

class pAUC_CVaR_loss(nn.Module):
    """Partial-AUC surrogate loss with one CVaR threshold per positive example.

    For every (positive, negative) pair in a mini-batch a pairwise surrogate
    loss is computed.  Only pairs whose loss exceeds the threshold stored for
    the corresponding positive example contribute to the returned value, and
    each threshold is moved by one step of size ``eta / pos_length`` on every
    call to :meth:`forward`.
    """

    def __init__(self, pos_length, num_neg, threshold=1.0, beta=0.2, loss_type='sh', toppush=False):
        """
        param
        pos_length: number of positive examples for the training data
        num_neg: number of negative samples for each mini-batch
        threshold: margin for basic AUC loss
        beta: FPR upper bound for pAUC used for SOPA; it is snapped to a
            multiple of 1/num_neg and never allowed to drop below 1/num_neg
        loss_type: basic AUC loss to apply, 'sh' (squared hinge) or
            'sq' (squared)
        toppush: when true, beta is forced to 1/num_neg

        The step size ``eta`` of the threshold update starts at 1.0; change it
        with ``set_eta`` / ``update_eta``.

        raises
        ValueError: if loss_type is neither 'sh' nor 'sq'
        """
        super(pAUC_CVaR_loss, self).__init__()

        # Reject unknown surrogate names right away; otherwise the failure
        # would only show up later as a missing attribute inside forward().
        if loss_type == 'sh':
            self.Loss = squared_hinge_loss
        elif loss_type == 'sq':
            self.Loss = squared_loss
        else:
            raise ValueError(
                "loss_type must be 'sh' or 'sq', got {!r}".format(loss_type))

        # Express beta as (kept negatives) / num_neg.  Keeping at least one
        # negative prevents beta from rounding down to 0, which would make
        # every division by beta blow up.
        if toppush:
            kept_negatives = 1
        else:
            kept_negatives = max(1, round(beta * num_neg))
        self.beta = kept_negatives / num_neg

        self.eta = 1.0
        self.num_neg = num_neg
        self.pos_length = pos_length
        self.threshold = threshold

        # One threshold per positive training example, shape (pos_length, 1).
        # Placed on the GPU when one is available (as before) but no longer
        # crashes on CPU-only machines; forward() follows y_pred's device.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.lambda_pos = torch.zeros(pos_length, 1, device=device)

        print('The loss type is :', loss_type)

    def set_eta(self, eta):
        """Set the step size used for the threshold update."""
        self.eta = eta

    def update_eta(self, decay_factor):
        """Divide the current step size by ``decay_factor``."""
        self.eta = self.eta / decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Compute the loss for one mini-batch and update the thresholds.

        param
        y_pred: predicted scores for the mini-batch
        y_true: labels used as a boolean mask over y_pred; entries equal to 1
            are treated as positives and entries equal to 0 as negatives
        index_p: indices into the per-positive threshold table for the
            positives of this batch (presumably in the order they appear in
            y_pred -- confirm against the caller)
        index_n: not used; kept so existing callers keep working

        return
        scalar tensor: mean over all pairs of the surrogate loss masked by
        "loss > threshold", divided by beta.  A zero that is still connected
        to y_pred is returned when the batch has no positive or no negative.
        """
        scores_pos = y_pred[y_true == 1].view(-1, 1)   # (n_pos, 1)
        scores_neg = y_pred[y_true == 0].view(1, -1)   # (1, n_neg)

        # Without both classes there are no pairs; a mean over an empty
        # tensor would be NaN, so return a graph-connected zero and leave
        # the thresholds untouched.
        if scores_pos.numel() == 0 or scores_neg.numel() == 0:
            return y_pred.sum() * 0.0

        # Keep the threshold table on the same device as the predictions.
        if self.lambda_pos.device != y_pred.device:
            self.lambda_pos = self.lambda_pos.to(y_pred.device)

        # Broadcasting yields the (n_pos, n_neg) matrix of score differences.
        difference = scores_pos - scores_neg
        pair_loss = self.Loss(margin=self.threshold, t=difference)  # before mean() operation.

        with torch.no_grad():
            current = self.lambda_pos[index_p]
            # Pairs that are above the threshold of their positive example.
            active = pair_loss > current
            active_ratio = active.sum(dim=1, keepdim=True) / (self.beta * self.num_neg)
            # The mask above was taken with the old thresholds; the table is
            # updated afterwards for the next call.
            self.lambda_pos[index_p] = current - self.eta / self.pos_length * (1 - active_ratio)

        return torch.mean(active * pair_loss) / self.beta




